-
Notifications
You must be signed in to change notification settings - Fork 10
ENH: jax_autojit
#284
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
ENH: jax_autojit
#284
Conversation
""" | ||
|
||
@override | ||
def persistent_id(self, obj: object) -> Literal[0, 1, None]: # pyright: ignore[reportIncompatibleMethodOverride] # numpydoc ignore=GL08 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
To say that I am aggrieved by pyright and numpydoc would be a gentle understatement at this point.
""" | ||
Register upon first use instead of at import time, to avoid | ||
globally importing JAX. | ||
""" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Aside note on design: Dask avoids this exact problem by not requiring any decorator and instead duck-type checking for uniquely named dunder methods called __dask_<...>__
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Note: everything in this module is a private helper.
class Deprecated(enum.Enum): | ||
"""Unique type for deprecated parameters.""" | ||
|
||
DEPRECATED = 1 | ||
|
||
|
||
DEPRECATED = Deprecated.DEPRECATED |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I couldn't find anything to help here in the standard library, and I find that odd. Did I miss anything? I was expecting something like @functools.deprecated_args("static_argnames", "static_argnums")
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There is https://docs.python.org/3/library/warnings.html#warnings.deprecated that could be used in conjunction with typing.overload.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good pointer; however it requires Python >=3.13 so I can't use it 😞
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I have a minor nit and a note, otherwise it looks good to me as it reduces the maintenance burden of adding/updating static_argnames/argnums
arguments.
Btw, I enjoyed reading pickle_flatten
and pickle_flatten
implementations as it felt like less is more.
Thanks, @crusaderky!
class Deprecated(enum.Enum): | ||
"""Unique type for deprecated parameters.""" | ||
|
||
DEPRECATED = 1 | ||
|
||
|
||
DEPRECATED = Deprecated.DEPRECATED |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There is https://docs.python.org/3/library/warnings.html#warnings.deprecated that could be used in conjunction with typing.overload.
For what it's worth, one of the main reasons we haven't implemented something like this in JAX is because it tends to negatively impact dispatch times. The output of With that background, I'd like a bit of clarification here: is this intended for use only in testing paths, or do you imagine this will be used within the dispatch path for libraries like scipy that implement the Array API? |
I think this section of the top post answers your question, Jake:
|
Co-authored-by: Pearu Peterson <pearu.peterson@gmail.com>
scipy/scipy#22909 does not cause any issues to crop up. |
Supercharge
xpx.testing.lazy_xp_function
.This is a fairly high level change; @rgommers @pearu @jakevdp I could use your feedback.
Matching SciPy PR: scipy/scipy#22909
Automatic static arguments
scipy has seen a proliferation of
lazy_xp_function(func, static_argnames=(...))
in the initial section of many test modules. This information is only useful for JAX and is quite verbose.With scipy/scipy#22686, this informational quirk that is both specific to JAX and to unit tests moves to the implementation modules.
With this PR,
lazy_xp_function
no longer accepts parametersstatic_argnames
andstatic_argnums
. Instead, all arguments that are not JAX arrays are automatically treated as static. Note that this behaviour is generally desirable in unit testing but a bad idea in production. Consider:jax_autojit
is a new internal function of array-api-extra.In the above example,
j2
requires a lot less setup to be tested effectively, but on the flip side it means that it will be re-traced for every different value of y, which likely makes it not fit for purpose in production.To clarify:
jax_autojit
is applied and removed on the fly within tests with thexp
fixture and it is never used outside of unit testing.Wrapped inputs and output
There are a few cases of scipy functions returning bespoke containers with arrays inside instead of simple tuples, namedtuples, lists, or dicts of arrays.
Two such examples are
In main, these can't be tested with
lazy_xp_function
, and as of scipy/scipy#22686 the issue also impacts documentation.This PR lifts this restriction and allows completely arbitrary objects as parameters and as return values of the functions. If these objects internally contain JAX arrays,
lazy_xp_function
will now automatically extract them, pass them through the JIT, and reassemble everything for the test function to observe the result.The rationale is that, in real life, users are unlikely to wrap the scipy functions with jax.jit directly; instead they are more likely to consume their outputs in their own functions and then wrap those with jax.jit:
Static non-hashable arguments
Non-hashable objects can now be static.
Consider:
This PR fixes it; all objects in the input and output of the function need only be hashable or pickleable.
Dask materialization raises inside a container
This also fixes a Dask-specific bug where the graph materialization raises:
The above works when f returns plain Dask array objects or tuples or lists thereof, even if the graph would otherwise be discarded and never computed, thanks to the
lazy_xp_function
machinery, but used to fail when the return value is an opaque container with Dask arrays inside. This PR fixes it.